Sometimes I get this error:

XlaRuntimeError                           Traceback (most recent call last)
Cell In[5], line 56
     44 parameterized_pdf = Graph.pmf_from_graph(
     45     graph,
     46     # discrete=False,
     47 )
     49 # test_theta = jnp.array(true_theta)
     50 # test_times = jnp.array([0.5, 1.0, 1.5])
     51 # test_pdf = pparameterized_pdfdf(test_theta, observed_data)
   (...)
     54 
     55 #learning_rate = 0.01
---> 56 svgd = SVGD(parameterized_pdf, observed_data, theta_dim=len(true_theta),
     57             n_iterations=n_iterations,
     58 
     59             # learning_rate=learning_rate,
     60             verbose=True,
     61             )
Core Computation:     5614.850s (  6.6%)
JAX Overhead:        60470.898s ( 70.9%)
SVGD Algorithm:       8773.522s ( 10.3%)
NumPy Operations:        0.000s (  0.0%)
Other:                9936.396s ( 11.6%)

JAX Overhead Breakdown:

  vmap/pmap       15285.256s ( 25.3% of JAX, 17.9% total)
  Other JAX       14738.447s ( 24.4% of JAX, 17.3% total)
  Dispatch        12227.066s ( 20.2% of JAX, 14.3% total)
  Callbacks/FFI    9097.319s ( 15.0% of JAX, 10.7% total)
  Primitives       5937.572s (  9.8% of JAX,  7.0% total)
  Tracing          3057.867s (  5.1% of JAX,  3.6% total)
  Array ops         127.371s (  0.2% of JAX,  0.1% total)

Top Computation Functions:

  1. pmf_function 2808.155s ( 3.3%) [ 840000 calls]
  2. _compute_pmf_from_ctypes 2806.695s ( 3.3%) [ 840000 calls]

Top JAX Dispatch Functions:

  1. call_wrapped 3057.210s ( 3.6%) [ 2125 calls]
  2. bind 3056.626s ( 3.6%) [ 4610 calls]
  3. _true_bind 3056.620s ( 3.6%) [ 4610 calls]
  4. bind_with_trace 3056.610s ( 3.6%) [ 4628 calls]

Top JAX vmap/pmap Functions:

  1. vmap_f 3057.616s ( 3.6%) [ 2000 calls]
  2. _batch_outer 3057.158s ( 3.6%) [ 2000 calls]
  3. _batch_inner 3057.107s ( 3.6%) [ 2000 calls]
  4. flatten_fun_for_vmap 3056.901s ( 3.6%) [ 2000 calls]
  5. _pjit_batcher 3056.473s ( 3.6%) [ 2000 calls]

Top JAX Array Functions:

  1. device_put 127.371s ( 0.1%) [ 840001 calls]

Top JAX Primitive Functions:

  1. process_primitive 3056.528s ( 3.6%) [ 2100 calls]
  2. process_primitive 2881.044s ( 3.4%) [ 2000 calls]

Top JAX Callback/FFI Functions:

  1. _wrapped_callback 3037.308s ( 3.6%) [ 840000 calls]
  2. _callback 3030.319s ( 3.6%) [ 840000 calls]
  3. pure_callback_impl 3029.693s ( 3.6%) [ 840000 calls]

Top SVGD Functions:

  1. svgd_step 3055.878s ( 3.6%) [ 2000 calls]
  2. fit 2858.822s ( 3.4%) [ 3 calls]
  3. run_svgd 2858.821s ( 3.4%) [ 3 calls]

Optimization Recommendations:

• HIGH vmap/pmap overhead (17.9%):
  - vmap overhead is normal for vectorized operations
  - Consider using explicit loops if vmap is over small batches
  - Check if batch size can be increased

================================================================================
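Many entries in this report show almost the same time, about 3057 s. These include `call_wrapped`, `bind`, `_true_bind`, `bind_with_trace`, `vmap_f`, `_batch_outer`, `_pjit_batcher`, `process_primitive` and `svgd_step` itself. That pattern is what you see for the cumulative times of functions that sit on the same call stack. Suppose the category totals are built by adding such cumulative times (an assumption, since the report does not say how it computes them). Then the same seconds are counted once per stack frame, which would inflate "JAX Overhead" well beyond the time actually spent. The sketch below uses only the standard-library `cProfile`/`pstats` modules to show the effect with three nested functions. `outer`, `middle` and `inner` are hypothetical names.

```python
import cProfile
import pstats
import time


def inner():
    time.sleep(0.05)  # the only real work: about 50 ms of wall time


def middle():
    inner()  # adds no work of its own


def outer():
    middle()  # adds no work of its own


profiler = cProfile.Profile()
profiler.enable()
outer()
profiler.disable()

# pstats.Stats.stats maps (filename, lineno, funcname) to
# (primitive calls, total calls, own time, cumulative time, callers).
stats = pstats.Stats(profiler)
cumulative = {
    func_name: entry[3]
    for (_, _, func_name), entry in stats.stats.items()
    if func_name in {"outer", "middle", "inner"}
}

wall_time = cumulative["outer"]          # time actually spent
naive_total = sum(cumulative.values())   # the same 50 ms counted once per frame
print({name: round(t, 3) for name, t in cumulative.items()})
print(f"naive sum / wall time = {naive_total / wall_time:.2f}")  # close to 3
```

The per-function cumulative times therefore say more than the category percentages about where the run spends its time.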

:::
:::
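The profile shows `pmf_function` and `_compute_pmf_from_ctypes` running 840000 times, the same count as `pure_callback_impl`. The likelihood is therefore evaluated on the host through a JAX callback. In many JAX versions, a Python exception raised inside a host callback is re-raised as an `XlaRuntimeError` whose message contains the original error. A failure that happens only for a few parameter values would then show up only sometimes. That is plausible given that all three 95% intervals in the summary below extend below zero.

One way to find the offending input is to wrap the host-side function so that it records its arguments before re-raising. The sketch below is plain Python/NumPy and assumes you can reach the function that is handed to the callback. `log_failing_inputs`, `failures` and `toy_pmf` are hypothetical names, and `toy_pmf` only stands in for the real ctypes-backed pmf.

```python
import functools

import numpy as np

failures = []  # hypothetical log of (arguments, message) for failing calls


def log_failing_inputs(fn):
    """Wrap a host-side function so that inputs which fail are kept for inspection."""

    @functools.wraps(fn)
    def wrapper(*args):
        try:
            result = fn(*args)
        except Exception as exc:
            # Copy the inputs: buffers passed in by the caller may be reused later.
            failures.append(([np.array(a, copy=True) for a in args], repr(exc)))
            raise  # re-raise so the original failure is still reported
        if not np.all(np.isfinite(result)):
            # NaN/inf does not raise, but is just as worth recording.
            failures.append(([np.array(a, copy=True) for a in args], "non-finite result"))
        return result

    return wrapper


@log_failing_inputs
def toy_pmf(theta, times):
    """Hypothetical stand-in for the real pmf: rejects non-positive rates."""
    theta = np.asarray(theta, dtype=float)
    if np.any(theta <= 0):
        raise ValueError(f"rates must be positive, got {theta}")
    return theta[0] * np.exp(-theta[0] * np.asarray(times, dtype=float))


toy_pmf(np.array([2.0, 3.0, 2.0]), np.array([0.5, 1.0]))  # succeeds, nothing recorded
try:
    toy_pmf(np.array([-1.0, 3.0, 2.0]), np.array([0.5, 1.0]))
except ValueError:
    pass  # under a JAX callback this would surface as the runtime error instead
print(failures[0])  # arguments and message of the failing call
```

Because the wrapper re-raises, behaviour is unchanged apart from the logging. After the error occurs, `failures` holds the parameter values that triggered it.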


::: {#f1857753 .cell execution_count=7}
``` {.python .cell-code}
svgd.plot_convergence() ;
```
:::

Figure: SVGD convergence, two panels: "Mean Convergence" (Posterior Mean vs. SVGD Iteration) and "Std Convergence" (Posterior Std vs. SVGD Iteration).
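The convergence plot tracks the particle mean and standard deviation against the SVGD iteration. To illustrate what produces such curves, the sketch below runs the textbook SVGD update in NumPy and records the same two statistics after every step. Each particle moves along $\phi(x) = \frac{1}{n}\sum_{j}\left[k(x_j, x)\,\nabla_{x_j}\log p(x_j) + \nabla_{x_j} k(x_j, x)\right]$, using an RBF kernel with the median-heuristic bandwidth. This is a generic illustration, not the `svgd_step` used by `ptdalgorithms`. The 2-D Gaussian target, the step size and the iteration count are arbitrary assumptions.

```python
import numpy as np


def rbf_kernel(x):
    """Kernel matrix k[j, i] = k(x_j, x_i) and, per particle i, sum_j grad_{x_j} k(x_j, x_i)."""
    diff = x[:, None, :] - x[None, :, :]              # diff[j, i] = x_j - x_i, shape (n, n, d)
    sq_dist = np.sum(diff**2, axis=-1)                # squared distances, shape (n, n)
    h = np.median(sq_dist) / np.log(x.shape[0] + 1)   # median-heuristic bandwidth
    h = max(h, 1e-8)                                  # guard against collapsed particles
    k = np.exp(-sq_dist / h)
    grad_k = -2.0 / h * np.einsum("ji,jid->id", k, diff)  # repulsive term, shape (n, d)
    return k, grad_k


def svgd_step(x, grad_log_p, step_size):
    """One SVGD update: kernel-weighted score plus kernel repulsion."""
    k, grad_k = rbf_kernel(x)
    phi = (k.T @ grad_log_p(x) + grad_k) / x.shape[0]
    return x + step_size * phi


# Hypothetical target: independent Gaussians with mean mu and standard deviation sigma.
mu = np.array([1.0, -2.0])
sigma = 0.5


def grad_log_p(x):
    return -(x - mu) / sigma**2


rng = np.random.default_rng(0)
particles = rng.normal(size=(60, 2))
mean_history, std_history = [], []
for _ in range(500):
    particles = svgd_step(particles, grad_log_p, step_size=0.05)
    mean_history.append(particles.mean(axis=0))   # analogous to "Mean Convergence"
    std_history.append(particles.std(axis=0))     # analogous to "Std Convergence"
```

Flat tails in both histories indicate that the particles have settled. Curves that are still drifting at the last iteration suggest more iterations or a different step size.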

! ls ../../galleries/
_build                   c_api                    numpy.css
_extensions              custom-dark.scss         numpy.theme
_freeze                  custom.scss              objects.txt
_inv                     galleries                pages
_quarto.yml              index.qmd                r_api
api                      logo.png                 styles.css
autodoc.mustache         numpy-dark.theme
banner.png               numpy-navbar-sidebar.css
/Users/kmt/PtDAlgorithms/.pixi/envs/default/lib/python3.13/pty.py:95: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid, fd = os.forkpty()
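The `RuntimeWarning` above is triggered by the `! ls` shell escape. As the warning's location shows, the call ends up in `os.forkpty()` inside `pty.py`. JAX warns because forking a process that already runs JAX's threads can deadlock. A directory listing does not need a child process, since `pathlib` can read the directory in-process. A minimal sketch follows; `list_dir` is a hypothetical helper name, and the demo uses a throwaway directory instead of `../../galleries/`.

```python
import tempfile
from pathlib import Path


def list_dir(path):
    """Return the sorted entry names of `path` without spawning a subprocess."""
    return sorted(entry.name for entry in Path(path).iterdir())


# Demo on a temporary directory; in the notebook: print(*list_dir("../../galleries/"), sep="\n")
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "index.qmd").touch()
    (Path(tmp) / "_build").mkdir()
    demo_listing = list_dir(tmp)
print(demo_listing)  # ['_build', 'index.qmd']
```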
fig, axes = svgd.plot_trace()
fig.savefig('../../galleries/examples/images/svgd_convergence.webp', format='webp')
60
60
60

svgd.summary()
======================================================================
SVGD Inference Summary
======================================================================
Number of particles: 60
Number of iterations: 1000
Parameter dimension: 3

Posterior estimates:
  θ_0: 136.4124 ± 128.4685
       95% CI: [-29.4463, 352.3441]
  θ_1: 1633.8782 ± 986.7968
       95% CI: [-744.5340, 3198.3957]
  θ_2: 91.2288 ± 96.7373
       95% CI: [-107.4353, 276.1241]
======================================================================
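The reported intervals are not mean ± 1.96 standard deviations. For θ_0 that rule would give roughly [-115.4, 388.2], not [-29.4463, 352.3441], so they are presumably empirical 2.5% and 97.5% quantiles of the 60 particles. All three lower bounds are also negative. If the θ are rates that must be positive, the fraction of particles below zero is worth checking.

The sketch below recomputes these quantities with NumPy. It assumes the particles form an array of shape `(n_particles, theta_dim)`. That layout matches the sizes printed above but is not confirmed by the output. The random particles are placeholders, and `summarize` is a hypothetical helper.

```python
import numpy as np


def summarize(particles, level=0.95):
    """Per-parameter mean, std and central quantile interval for an (n, d) particle array."""
    particles = np.asarray(particles, dtype=float)
    tail = 100 * (1 - level) / 2                      # 2.5 for a 95% interval
    lower, upper = np.percentile(particles, [tail, 100 - tail], axis=0)
    return {
        "mean": particles.mean(axis=0),
        "std": particles.std(axis=0),                 # ddof=0; the summary's choice is unknown
        "ci": np.column_stack([lower, upper]),        # shape (d, 2)
        "frac_negative": (particles < 0).mean(axis=0),
    }


# Placeholder particles with the printed sizes (60 particles, 3 parameters).
# With the fitted object one would pass svgd.get_results()["particles"] (assumed layout).
rng = np.random.default_rng(1)
demo = summarize(rng.normal(size=(60, 3)))
for i, (m, s, (lo, hi)) in enumerate(zip(demo["mean"], demo["std"], demo["ci"])):
    print(f"θ_{i}: {m:.4f} ± {s:.4f}   95% CI: [{lo:.4f}, {hi:.4f}]")
```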
#svgd.animate(thin=50)
svgd.plot_pairwise(true_theta=true_theta,
                   # param_names=['jump', 'flood_left', 'flood_right'],
                   show_transformed=True,
                   ) ;
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 1
----> 1 svgd.plot_pairwise(true_theta=true_theta,
      2                    # param_names=['jump', 'flood_left', 'flood_right'],
      3     show_transformed=True,
      4                    ) ;                   

File ~/PtDAlgorithms/src/ptdalgorithms/svgd.py:1794, in plot_pairwise(self, true_theta, param_names, figsize, save_path, show_transformed)
   1773 def plot_pairwise(self, true_theta=None, param_names=None,
   1774                  figsize=None, save_path=None, show_transformed=True):
   1775     """
   1776     Plot pairwise scatter plots for all parameter pairs.
   1777 
   1778     Parameters
   1779     ----------
   1780     true_theta : array_like, optional
   1781         True parameter values (if known) to overlay on plot
   1782     param_names : list of str, optional
   1783         Names for each parameter dimension
   1784     figsize : tuple, optional
   1785         Figure size (width, height)
   1786     save_path : str, optional
   1787         Path to save the plot
   1788     show_transformed : bool, default=True
   1789         If True, show transformed (constrained) parameter values.
   1790         If False, show untransformed (unconstrained) values.
   1791         Only relevant when using parameter transformations.
   1792 
   1793     Returns
-> 1794     -------
   1795     fig, axes
   1796         Matplotlib figure and axes objects
   1797     """
   1798     if not self.is_fitted:
   1799         raise RuntimeError("Must call fit() before plotting")

RuntimeError: Must call fit() before plotting
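In this traceback the arrow points at line 1794, inside the docstring, rather than at the `raise` on line 1799. IPython prints the file as it currently is on disk next to the line numbers of the code that is actually loaded. A mismatch like this therefore usually means `svgd.py` was edited after it was imported. Objects created before a reload keep the class from the old module, so the notebook can be running different code from what the file shows. Restarting the kernel, or re-importing and re-creating the object, avoids the mismatch.

The sketch below demonstrates the behaviour with the standard library only. It uses a throwaway module named `demo_mod` (hypothetical).

```python
import importlib
import sys
import tempfile
from pathlib import Path

old_flag = sys.dont_write_bytecode
sys.dont_write_bytecode = True        # make reload() recompile from the edited source
sys.modules.pop("demo_mod", None)     # start clean if this cell is re-run
try:
    with tempfile.TemporaryDirectory() as tmp:
        source = Path(tmp) / "demo_mod.py"
        source.write_text("class Model:\n    def version(self):\n        return 'old'\n")
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()  # let the import system see the new file
        import demo_mod

        old_model = demo_mod.Model()   # object created before the edit

        # Simulate editing the file while the kernel keeps running, then reloading.
        source.write_text("class Model:\n    def version(self):\n        return 'new, edited'\n")
        importlib.reload(demo_mod)
        new_model = demo_mod.Model()
        sys.path.remove(tmp)
finally:
    sys.dont_write_bytecode = old_flag

print(old_model.version())                 # still 'old': bound to the old class
print(new_model.version())                 # 'new, edited'
print(type(old_model) is demo_mod.Model)   # False: two different class objects
```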
anim = svgd.animate_pairwise(
    true_theta=[2.0, 3.0, 2.0],
    param_names=['jump', 'flood_left', 'flood_right'],
    thin=20,
    show_transformed=True,
)
anim  # Display in Jupyter
[INFO] Animation.save using <class 'matplotlib.animation.HTMLWriter'>
[INFO] figure size in inches has been adjusted from 9.0 x 6.8999999999999995 to 9.0 x 6.9
results = svgd.get_results()
results.keys()
dict_keys(['particles', 'theta_mean', 'theta_std', 'history'])